'''
Author: SlytherinGe
LastEditTime: 2021-06-21 10:07:03
'''
from mmdet.apis import init_detector, inference_detector, show_result_pyplot
import mmcv
import os

# Detector configuration and the checkpoint whose weights are loaded into it.
# Both paths are specific to the original author's machine; adjust as needed.
config_file = 'configs/my_configs/ssdd/focusfcos_hrnet.py'
checkpoint_file = '/media/gejunyao/Disk/Gejunyao/exp_results/mmdetection_files/SSDD/focus/functional_test3/epoch_24.pth'

# The detector is built once, at module load, on the first CUDA device.
# Pass device='cpu' here instead to run without a GPU.
model = init_detector(config_file, checkpoint_file, device='cuda:0')

# Images to run through the detector. They all live in the same directory,
# so only the numeric stems are spelled out.
test_imgs = [
    '/media/gejunyao/Disk1/Datasets/SSDD/VOC2012/JPEGImages/' + stem + '.jpg'
    for stem in ('000762', '001112', '000231', '000395')
]

# Folder that receives one dumped result file per test image.
# Set to None to run inference without saving anything.
PAOI_SAVE_FOLDER = '/media/gejunyao/Disk/Gejunyao/exp_results/visualization/middle_part/paoi_feats/'


def build_dump_path(img_path, save_folder):
    """Return the pickle path under ``save_folder`` for ``img_path``.

    Args:
        img_path (str): Path of the source image.
        save_folder (str): Directory that will hold the dumped result.

    Returns:
        str: ``save_folder`` joined with the image's base name, with its
        final extension replaced by ``.pkl``.
    """
    # splitext strips only the last extension, so names containing extra
    # dots (e.g. ``scene.v2.jpg``) keep their full stem and cannot collide
    # with each other the way ``split('.')[0]`` would make them.
    stem = os.path.splitext(os.path.basename(img_path))[0]
    return os.path.join(save_folder, stem + '.pkl')


if __name__ == '__main__':

    # Create the output folder up front so a missing directory is reported
    # (or fixed) before any inference time is spent.
    if PAOI_SAVE_FOLDER is not None:
        os.makedirs(PAOI_SAVE_FOLDER, exist_ok=True)

    for img in test_imgs:
        # The image is passed as a one-element list, exactly as before, so
        # the structure of the dumped object is unchanged.
        result = inference_detector(model, [img])

        if PAOI_SAVE_FOLDER is not None:
            mmcv.dump(result, build_dump_path(img, PAOI_SAVE_FOLDER))

        # To inspect detections visually, hand the result to
        # show_result_pyplot(model, img, ...).
        # NOTE(review): `result` is presumably a one-element list here because
        # the input was a list -- confirm against the installed mmdet before
        # indexing it.